/**
 * Copyright (c) 2024 Huawei Technologies Co., Ltd.
 * This file is a part of the CANN Open Software.
 * Licensed under CANN Open Software License Agreement Version 1.0 (the "License").
 * Please refer to the License for details. You may not use this file except in compliance with the License.
 * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
 * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
 * See LICENSE in the root of the software repository for the full text of the License.
 */

/*!
 * \file topk.h
 * \brief
 */
#ifndef LIB_SORT_TOPK_H
#define LIB_SORT_TOPK_H

#if __CCE_AICORE__ >= 200

#include "kernel_tensor.h"
#include "kernel_utils.h"
#include "kernel_tiling/kernel_tiling.h"
#include "../../impl/sort/topk/topk_common_utils.h"
#include "../../impl/sort/topk/topk_common_impl.h"

#if __CCE_AICORE__ == 220
#include "../../impl/sort/topk/topk_v220_impl.h"
#elif __CCE_AICORE__ == 200
#include "../../impl/sort/topk/topk_v200_impl.h"
#endif

namespace AscendC {
#pragma begin_pipe(V)
/*
 * @ingroup TopK
 * @brief Get the top k maximum (or minimum) values of the last dimension, together with their indices.
 *        This overload takes a caller-provided temporary buffer (tmpLocal).
 *        On AI Cube cores (ASCEND_IS_AIC) the call returns immediately without doing any work.
 *        In CPU debug builds (ASCENDC_CPU_DEBUG), k and topKInfo are validated by TopkInputCheck first.
 *        The work is then forwarded to TopKNormal or TopKNSmall, depending on topkMode.
 * @tparam T: Data type to be sorted, half or float.
 * @tparam isInitIndex: Whether the caller provides the index of the input data.
 *         true:  srcIndexLocal holds the index of the input data.
 *         false: the index is generated by the TopK API.
 * @tparam isHasfinish: Enables marking the sort of some rows as invalid (see finishLocal).
 *         true enables the feature and false disables it.
 *         Normal mode accepts true or false; small mode only accepts false.
 * @tparam isReuseSrc: Whether temporary variables may reuse the input memory.
 *         Reserved; keep the default value false.
 * @tparam topkMode: TopKMode::TOPK_NORMAL (normal mode) or TopKMode::TOPK_NSMALL (small mode).
 *         Small mode is recommended when the inner axis length is 32, because it performs better.
 * @param [out] dstValueLocal: Receives the k sorted values.
 * @param [out] dstIndexLocal: Receives the indices corresponding to the k sorted values.
 * @param [in] srcLocal: Input values to be sorted.
 * @param [in] srcIndexLocal: Indices that correspond to the values in srcLocal.
 * @param [in] finishLocal: Marks rows whose sort is invalid. Its shape is (outer, 1).
 *         Used together with isHasfinish.
 * @param [in] tmpLocal: Temporary space for intermediate results of the internal calculation.
 * @param [in] k: Number of maximum (or minimum) values, and their indices, to obtain.
 * @param [in] tilling: Tiling information required for the TopK calculation.
 * @param [in] topKInfo: Shape information of srcLocal.
 * @param [in] isLargest: Sort order. true selects descending order (largest values);
 *         false selects ascending order (smallest values).
 */
template <typename T, bool isInitIndex = false, bool isHasfinish = false, bool isReuseSrc = false,
    enum TopKMode topkMode = TopKMode::TOPK_NORMAL>
__aicore__ inline void TopK(const LocalTensor<T> &dstValueLocal, const LocalTensor<int32_t> &dstIndexLocal,
    const LocalTensor<T> &srcLocal, const LocalTensor<int32_t> &srcIndexLocal, const LocalTensor<bool> &finishLocal,
    const LocalTensor<uint8_t> &tmpLocal, const int32_t k, const TopkTiling &tilling, const TopKInfo &topKInfo,
    const bool isLargest = true)
{
    // This API is vector-only: AI Cube cores leave without touching any tensor.
    if ASCEND_IS_AIC {
        return;
    }

#if ASCENDC_CPU_DEBUG
    // CPU-simulation builds check k and the shape description before any work starts.
    TopkInputCheck<T, isInitIndex, topkMode>(k, topKInfo);
#endif

    // The mode is a template parameter, so the dispatch below is resolved at compile time.
    constexpr bool isNormalMode = (topkMode == TopKMode::TOPK_NORMAL);
    constexpr bool isNSmallMode = (topkMode == TopKMode::TOPK_NSMALL);

    if constexpr (isNormalMode) {
        TopKNormal<T, isInitIndex, isHasfinish, isReuseSrc>(
            dstValueLocal, dstIndexLocal,
            srcLocal, srcIndexLocal, finishLocal,
            tmpLocal, k, tilling, topKInfo, isLargest);
    }
    if constexpr (isNSmallMode) {
        TopKNSmall<T, isInitIndex, isHasfinish, isReuseSrc>(
            dstValueLocal, dstIndexLocal,
            srcLocal, srcIndexLocal, finishLocal,
            tmpLocal, k, tilling, topKInfo, isLargest);
    }
}

/*
 * @ingroup TopK
 * @brief Get the top k maximum (or minimum) values of the last dimension, together with their indices.
 *        This overload takes no caller-provided temporary buffer (no tmpLocal parameter).
 *        NOTE(review): temporary space is presumably obtained inside TopKNormal / TopKNSmall.
 *        This header does not show how; confirm against the implementation headers.
 *        On AI Cube cores (ASCEND_IS_AIC) the call returns immediately without doing any work.
 *        In CPU debug builds (ASCENDC_CPU_DEBUG), k and topKInfo are validated by TopkInputCheck first.
 *        The work is then forwarded to TopKNormal or TopKNSmall, depending on topkMode.
 * @tparam T: Data type to be sorted, half or float.
 * @tparam isInitIndex: Whether the caller provides the index of the input data.
 *         true:  srcIndexLocal holds the index of the input data.
 *         false: the index is generated by the TopK API.
 * @tparam isHasfinish: Enables marking the sort of some rows as invalid (see finishLocal).
 *         true enables the feature and false disables it.
 *         Normal mode accepts true or false; small mode only accepts false.
 * @tparam isReuseSrc: Whether temporary variables may reuse the input memory.
 *         Reserved; keep the default value false.
 * @tparam topkMode: TopKMode::TOPK_NORMAL (normal mode) or TopKMode::TOPK_NSMALL (small mode).
 *         Small mode is recommended when the inner axis length is 32, because it performs better.
 * @param [out] dstValueLocal: Receives the k sorted values.
 * @param [out] dstIndexLocal: Receives the indices corresponding to the k sorted values.
 * @param [in] srcLocal: Input values to be sorted.
 * @param [in] srcIndexLocal: Indices that correspond to the values in srcLocal.
 * @param [in] finishLocal: Marks rows whose sort is invalid. Its shape is (outer, 1).
 *         Used together with isHasfinish.
 * @param [in] k: Number of maximum (or minimum) values, and their indices, to obtain.
 * @param [in] tilling: Tiling information required for the TopK calculation.
 * @param [in] topKInfo: Shape information of srcLocal.
 * @param [in] isLargest: Sort order. true selects descending order (largest values);
 *         false selects ascending order (smallest values).
 */
template <typename T, bool isInitIndex = false, bool isHasfinish = false, bool isReuseSrc = false,
    enum TopKMode topkMode = TopKMode::TOPK_NORMAL>
__aicore__ inline void TopK(const LocalTensor<T> &dstValueLocal, const LocalTensor<int32_t> &dstIndexLocal,
    const LocalTensor<T> &srcLocal, const LocalTensor<int32_t> &srcIndexLocal, const LocalTensor<bool> &finishLocal,
    const int32_t k, const TopkTiling &tilling, const TopKInfo &topKInfo, const bool isLargest = true)
{
    // This API is vector-only: AI Cube cores leave without touching any tensor.
    if ASCEND_IS_AIC {
        return;
    }

#if ASCENDC_CPU_DEBUG
    // CPU-simulation builds check k and the shape description before any work starts.
    TopkInputCheck<T, isInitIndex, topkMode>(k, topKInfo);
#endif

    // The mode is a template parameter, so the dispatch below is resolved at compile time.
    constexpr bool isNormalMode = (topkMode == TopKMode::TOPK_NORMAL);
    constexpr bool isNSmallMode = (topkMode == TopKMode::TOPK_NSMALL);

    if constexpr (isNormalMode) {
        TopKNormal<T, isInitIndex, isHasfinish, isReuseSrc>(
            dstValueLocal, dstIndexLocal,
            srcLocal, srcIndexLocal, finishLocal,
            k, tilling, topKInfo, isLargest);
    }
    if constexpr (isNSmallMode) {
        TopKNSmall<T, isInitIndex, isHasfinish, isReuseSrc>(
            dstValueLocal, dstIndexLocal,
            srcLocal, srcIndexLocal, finishLocal,
            k, tilling, topKInfo, isLargest);
    }
}

#pragma end_pipe
}  // namespace AscendC

#endif

#endif  // LIB_SORT_TOPK_H